# -*- coding:utf-8 -*- 
#Author: OCEAN
#
#                            _ooOoo_
#                           o8888888o
#                           88" . "88
#                           (| -_- |)
#                           O\  =  /O
#                        ____/`---'\____
#                      .'  \\|     |//  `.
#                     /  \\|||  :  |||//  \
#                    /  _||||| -:- |||||-  \
#                    |   | \\\  -  /// |   |
#                    | \_|  ''\---/''  |   |
#                    \  .-\__  `-`  ___/-. /
#                  ___`. .'  /--.--\  `. . __
#               ."" '<  `.___\_<|>_/___.'  >'"".
#              | | :  `- \`.;`\ _ /`;.`/ - ` : | |
#              \  \ `-.   \_ __\ /__ _/   .-` /  /
#         ======`-.____`-.___\_____/___.-`____.-'======
#                            `=---='
#        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#                    Buddha-like Programming......
import torch
import math
import itertools
import cv2
import numpy as np

class DataEncoder:
    """Bounding-box encoding utilities.

    Only ``iou`` (and its smoke test ``test_iou``) is implemented so far;
    ``encode``, ``test_encode``, ``nms`` and ``decode`` are placeholders
    that currently do nothing and return ``None``.
    """

    def __init__(self):
        pass

    def iou(self, box1, box2):
        """
        Compute the pairwise intersection over union (IoU) of two box sets.

        Each box is [x1, y1, x2, y2].

        :param box1: (tensor) bounding boxes, sized [N,4].
        :param box2: (tensor) bounding boxes, sized [M,4].
        :return: (tensor) iou, sized [N,M]. Pairs whose union area is 0
            (e.g. two identical zero-area boxes) get IoU 0 rather than NaN.
        """
        # Broadcasting [N,1,2] against [1,M,2] yields every (i, j) pair at once.
        lt = torch.max(box1[:, None, :2], box2[None, :, :2])  # left top, [N,M,2]
        rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])  # right bottom, [N,M,2]

        # Non-overlapping pairs produce a negative extent; clip it to zero.
        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

        area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # [N,]
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # [M,]
        union = area1[:, None] + area2[None, :] - inter  # [N,M]

        iou = inter / union
        # A zero union would give 0/0 = NaN; define that IoU as 0 instead.
        return iou.masked_fill(union == 0, 0)

    def test_iou(self):
        """Print the IoU of one 10x10 box against two identical shifted boxes."""
        box1 = torch.IntTensor([0, 0, 10, 10])
        box1 = box1[None, :]  # [4] -> [1,4]
        box2 = torch.IntTensor([[5, 0, 15, 10], [5, 0, 15, 10]])
        print(box1.shape, box2.shape)
        print('iou', self.iou(box1, box2))

    def encode(self, boxes, classes, threshold=0.35):
        """
        Placeholder: not implemented yet; currently returns None.

        :param boxes: (tensor) boxes in (x1,y1,x2,y2) format, sized [num_obj, 4].
        :param classes: (tensor) class labels, presumably one per object,
            sized [num_obj] -- TODO confirm once implemented.
        :param threshold: presumably the IoU threshold used to match boxes
            (default 0.35) -- TODO confirm once implemented.
        :return: None (not implemented).
        """
        pass

    def test_encode(self, boxes, img, label):
        """Placeholder for an ``encode`` smoke test; currently does nothing."""
        pass

    def nms(self, bboxes, scores, threshold=0.5):
        """
        Placeholder for non-maximum suppression; currently returns None.

        NOTE(review): ``threshold`` is presumably the IoU overlap above which
        lower-scoring boxes are suppressed -- confirm when implementing.
        """
        pass

    def decode(self, loc, conf):
        """
        Convert the predicted loc/conf outputs into actual face boxes.

        Placeholder: not implemented yet; currently returns None.

        :param loc: predicted location output (format not yet defined).
        :param conf: predicted confidence output (format not yet defined).
        :return: None (not implemented).
        """
        pass


if __name__ == '__main__':
    # Script entry point: run the IoU smoke test and print its output.
    DataEncoder().test_iou()